import numpy as np
import json
import os
import matplotlib.pyplot as plt
from tqdm import tqdm


def _load_label_names(label_path):
    """Return class names read from *label_path*, indexed by ``id - 1``.

    The file is parsed as a single JSON object whose values each carry an
    integer ``"id"`` and a ``"name"``; the object's keys are not used. The
    returned list has one slot per entry in the object; a slot whose id never
    appears keeps the empty string ``""``.

    Raises:
        ValueError: if a label id lies outside ``1..len(label_map)``.
    """
    with open(label_path, "r", encoding="utf-8") as label_file:
        label_map = json.load(label_file)

    names = [""] * len(label_map)
    for label in label_map.values():
        label_id = label["id"]
        # Ids are 1-based. Without this check an id of 0 would silently
        # overwrite the last slot via Python's negative indexing.
        if not 1 <= label_id <= len(names):
            raise ValueError(
                f"label id {label_id} outside 1..{len(names)} in {label_path}"
            )
        names[label_id - 1] = label["name"]
    return names


def _count_detections(results_path, num_labels, threshold):
    """Count, per class, the detections whose score exceeds *threshold*.

    *results_path* is read one JSON object per line; each object has parallel
    ``"classes"`` (1-based class ids) and ``"scores"`` lists. Returns an int64
    array of length *num_labels* whose entry ``i`` counts detections of class
    ``i + 1`` with score strictly greater than *threshold*.

    Raises:
        ValueError: if a kept detection's class id lies outside
            ``1..num_labels``.
    """
    counts = np.zeros((num_labels,), dtype=np.int64)
    with open(results_path, "r", encoding="utf-8") as results_file:
        for line in tqdm(results_file):
            if not line.strip():
                continue  # tolerate blank lines (e.g. extra newlines at EOF)
            detections = json.loads(line)
            classes = np.asarray(detections["classes"], dtype=np.int64)
            scores = np.asarray(detections["scores"])
            kept = classes[scores > threshold]
            # Reject out-of-range ids explicitly: id 0 would otherwise wrap to
            # index -1 and be silently counted as the last class.
            if kept.size and (kept.min() < 1 or kept.max() > num_labels):
                raise ValueError(
                    f"class id outside 1..{num_labels} in {results_path}: "
                    f"{kept.tolist()}"
                )
            # np.add.at accumulates once per occurrence, so a class that
            # appears several times in `kept` is counted several times
            # (plain `counts[kept - 1] += 1` would count it only once).
            np.add.at(counts, kept - 1, 1)
    return counts


def _plot_counts(names, counts, title, out_file, top_k=None):
    """Bar-plot *counts* in descending order, save to *out_file*, then show.

    *names[i]* labels bar *counts[i]*. When *top_k* is given, only the
    *top_k* largest counts are drawn.
    """
    order = np.argsort(counts)[::-1][:top_k]
    positions = np.arange(len(order))

    # 200 x 40 inches at dpi=96 gives a 19200 x 3840 px image; presumably
    # this is so hundreds of vertical tick labels don't overlap.
    fig, ax = plt.subplots(figsize=(200, 40))
    ax.set_ylabel('Number of objects')
    ax.set_title(title)
    ax.set_xticks(positions)
    ax.set_xticklabels([names[i] for i in order], rotation='vertical')
    ax.bar(positions, counts[order])
    fig.savefig(out_file, dpi=96)
    plt.show()


def main(threshold=0.5, path="/scratch/shared/beegfs/chrisr/adiwol", top_k=None):
    """Plot how many detections above *threshold* fall into each class.

    Reads ``open_images_labels.txt`` and ``open_images_results.txt`` from
    *path*, counts detections per class whose score is strictly greater than
    *threshold*, and draws a bar chart sorted by count (largest first). The
    chart is saved as ``num_objects.png`` in the current working directory
    and then shown.

    Args:
        threshold: minimum score (exclusive) for a detection to be counted.
        path: directory containing the label and results files; also used
            as the chart title.
        top_k: if given, plot only the *top_k* most frequent classes.
    """
    names = _load_label_names(os.path.join(path, "open_images_labels.txt"))
    counts = _count_detections(
        os.path.join(path, "open_images_results.txt"), len(names), threshold
    )
    _plot_counts(names, counts, title=path, out_file='num_objects.png',
                 top_k=top_k)


if __name__ == "__main__":
    # Entry point: build the histogram only when run as a script, never on
    # import.
    main()
